from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np;
import time;
from sklearn import preprocessing

def loadFile(filepath):
	"""Read one document per line from *filepath* and build its TF-IDF matrix.

	Args:
		filepath: path to a text file, one document per line.

	Returns:
		word: list of vocabulary terms (the vectorizer's feature names).
		D: dense (n_docs x n_terms) np.matrix of TF-IDF weights.
	"""
	# `with` guarantees the file handle is closed even if reading fails
	with open(filepath) as f:
		corpus = [line for line in f]
	vectorizer = CountVectorizer()
	counts = vectorizer.fit_transform(corpus)
	tfidf = TfidfTransformer().fit_transform(counts)
	word = vectorizer.get_feature_names()
	# densify: the downstream factorization updates use full matrix algebra
	D = np.matrix(tfidf.toarray())
	return word, D

def init(D, k):
	"""Randomly initialize the two factor matrices for D ~ U * V.

	Args:
		D: (m x n) data matrix whose shape fixes the factor dimensions.
		k: number of latent topics.

	Returns:
		(U, V): uniform-random np.matrix factors of shape (m x k) and (k x n).
	"""
	rows, cols = D.shape
	U = np.matrix(np.random.rand(rows, k))
	V = np.matrix(np.random.rand(k, cols))
	return U, V

def generateLine(m, k, t):
	"""Build a 1 x k prior row: weight 1 for indices <= t, then a 1/(i-t)
	decay for the rest, rescaled so that the entries sum to m.

	Returns:
		(1 x k) np.matrix of non-negative weights summing to m.
	"""
	raw = np.array([1.0 if i <= t else 1.0 / (i - t) for i in range(k)])
	return np.matrix(raw * m / raw.sum())

def updateU(D, U, V, m, lmd1, sparse, P, a):
	"""One update of the document-topic matrix U in the factorization D ~ U*V.

	With lmd1 > 0, takes a gradient step of size *a* that also pulls the
	column sums of U toward the prior row P; otherwise applies the
	unregularized least-squares solution directly. The result is clipped to
	be non-negative and each row is rescaled to sum to 1.
	"""
	if lmd1 > 0:
		# penalty term pushing each column sum of U toward the profile P
		col_shift = np.ones((m, 1)) * (U.sum(0) - P)
		gradient = U * (V * V.T) - D * V.T + lmd1 * col_shift + sparse
		result = U - a * gradient
	else:
		# closed-form least-squares solution, sparsity shift applied first
		result = (D * V.T - sparse) * (V * V.T).I
	result = result.clip(0, np.inf)
	# row-normalize so each document's topic weights sum to one
	return result / result.sum(1)

# NOTE: superseded gradient-descent variant of updateV, kept for reference.
# def updateV(D,U,V,lmd2,sparse,a):
# 	if lmd2 > 0 and sparse > 0:
# 		R = V - a * (U.T * U * V - U.T * D -lmd2 * (V * V.T).I * V + sparse);
# 	else:
# 		if lmd2 > 0:
# 			R = V - a * (U.T * U * V - U.T * D -lmd2 * (V * V.T).I * V);
# 		else:
# 			if sparse > 0:
# 				R = V - a * (U.T * U * V - U.T * D + sparse);
# 			else:
# 				R = V - a * (U.T * U * V - U.T * D);
# 	R = R.clip(0,np.inf);
# 	# return np.matrix(preprocessing.normalize(R,norm="l2"));
# 	return R/R.sum(1)

def updateV(D, U, V, lmd2, sparse, a):
	"""One update of the topic-term matrix V in the factorization D ~ U*V.

	The regularized branches use a closed-form solve via (U^T U)^-1; only
	the fully unregularized case falls back to a gradient step of size *a*.
	The result is clipped to be non-negative and row-normalized to sum 1.
	"""
	if lmd2 > 0 and sparse > 0:
		target = U.T * D - sparse + lmd2 * (V * V.T).I * V
		result = (U.T * U).I * target
	elif lmd2 > 0:
		target = U.T * D + lmd2 * (V * V.T).I * V
		result = (U.T * U).I * target
	elif sparse > 0:
		result = (U.T * U).I * (U.T * D - sparse)
	else:
		# plain gradient step on the least-squares objective
		result = V - a * (U.T * U * V - U.T * D)
	result = result.clip(0, np.inf)
	# row-normalize so each topic's term weights sum to one
	return result / result.sum(1)

def show_top_terms(W, m, n, k, terms, num_terms=None):
	"""Print the highest-weighted terms for each of the k topics.

	Args:
		W: array-like of shape (n_terms, k); W[t][c] is term t's weight in topic c.
		m, n: unused; kept for interface compatibility with existing callers.
		k: number of topics/clusters.
		terms: vocabulary list indexed like W's rows.
		num_terms: how many terms to print per topic; when None, falls back
			to the module-level NUMTERMS global (the original behavior).
	"""
	if num_terms is None:
		num_terms = NUMTERMS  # original implementation read this global
	for c in range(k):
		# pair each term with its membership weight for this topic
		weighted = [(terms[t], W[t][c]) for t in range(len(W))]
		# ascending by weight, so the strongest terms sit at the end
		weighted.sort(key=lambda pair: pair[1])
		print("\nTopic %d:" % (c + 1))
		for j in range(1, num_terms + 1):
			print(u"%s" % weighted[-j][0])

if __name__ == "__main__":
	# --- load corpus and build the TF-IDF matrix ---
	filepath = "result/20newsgroup.data"
	word, D = loadFile(filepath)
	m, n = D.shape
	print(D.shape)
	# dump the vocabulary with its column index.
	# BUGFIX: the original `u"%s %d" % item, index` supplied one argument to
	# a two-specifier format string and raised TypeError; use enumerate and
	# pass a proper tuple instead of a hand-rolled counter.
	for index, item in enumerate(word):
		print(u"%s %d" % (item, index))
	# --- hyper-parameters ---
	k = 100          # number of latent topics
	NUMTERMS = 20    # top terms shown per topic (read by show_top_terms)
	lmd1 = 0         # column-sum prior weight for U updates (0 disables)
	lmd2 = 0.5       # regularization weight for V updates
	countIter = 100  # number of alternating update iterations
	sparse1 = 0      # sparsity shift for U updates
	sparse2 = 0      # sparsity shift for V updates
	U, V = init(D, k)
	save_file = "result/JXL_20news"
	P = generateLine(20000, k, 20)  # target column-mass profile for U
	t0 = time.time()
	for i in range(countIter):
		print("iteration %d :%f" % (i, time.time() - t0))
		U = updateU(D, U, V, m, lmd1, sparse1, P, 1e-6)
		V = updateV(D, U, V, lmd2, sparse2, 1e-3)
		# checkpoint the factor matrices every 100 iterations
		if (i + 1) % 100 == 0:
			tag = (save_file + "topic_" + str(k) + "_iter_" + str(i + 1)
				+ "_lmd1_" + str(lmd1) + "_lmd2_" + str(lmd2)
				+ "_sparse1_" + str(sparse1) + "_sparse2_" + str(sparse2))
			np.savetxt(tag + "-U.txt", U, delimiter=',')
			np.savetxt(tag + "-V.txt", V, delimiter=',')